package base

import (
	"errors"
	"gitee.com/lsy007/mysqlclient"
	"gitee.com/lsy007/mysqlclient/param"
	"github.com/golang/protobuf/proto"
)

// JoinInfo describes one JOIN clause of a join query. buildBuildCondition
// copies each entry field-for-field into a mysqlclient.JoinInfo before the
// request is sent.
type JoinInfo struct {
	Type  string // join kind; presumably e.g. "INNER"/"LEFT" — confirm against mysqlclient
	Table string // name of the table to join
	Alias string // alias used for the joined table
	On    string // ON condition expression for the join
}

// JoinGet fetches a single record from m.Table joined with the tables listed
// in info.Join. The protobuf-encoded row returned by the service is decoded
// into info.RelatedResult.
//
// Return values:
//   - (false, error) when info.Join holds no join conditions.
//   - (false, err) when the request fails.
//   - (false, nil) when the service reports no matching row.
//   - (true, nil) when a row was found and decoded.
//   - (true, err) when a row was found but decoding it failed.
func (m *MysqlHandle) JoinGet(info *GetInfo) (bool, error) {
	// A join query needs at least one join clause. len() also rejects a
	// non-nil empty slice, which would otherwise be sent as "no join".
	if len(info.Join) == 0 {
		return false, errors.New("JoinCondition not allowed to be empty")
	}
	// Send the join query.
	reply, err := m.SendJoinRequest(&mysqlclient.DBInfo{
		Table: m.Table, Func: "Get", Model: m.Model, AutoCondition: info.AutoCondition,
		Alias: info.Alias, Field: info.Field, Cols: info.Cols, Omit: info.Omit,
		Where: getWhere(info.Where), Join: buildBuildCondition(info.Join),
	})
	if err != nil || !reply.Has {
		return false, err
	}
	// Decode the returned row into the caller's message.
	err = proto.Unmarshal(reply.Model, info.RelatedResult)
	return reply.Has, err
}

// JoinFind fetches one page of records from m.Table joined with the tables
// listed in info.Join. The protobuf-encoded list returned by the service is
// decoded into info.ListResult.
//
// Pagination is 1-based: page info.Page of size info.Rows starts at offset
// (Page-1)*Rows. A Page <= 0 is treated as the first page rather than
// producing a negative offset.
//
// It returns (false, nil) when the service reports no rows and (false, err)
// on request failure. It returns true when rows were found; in that case the
// error is whatever decoding returned.
func (m *MysqlHandle) JoinFind(info *FindInfo) (bool, error) {
	// A join query needs at least one join clause. len() also rejects a
	// non-nil empty slice, which would otherwise be sent as "no join".
	if len(info.Join) == 0 {
		return false, errors.New("JoinCondition not allowed to be empty")
	}
	// Guard against Page <= 0, which would yield a negative LIMIT offset.
	start := (info.Page - 1) * info.Rows
	if start < 0 {
		start = 0
	}
	// Send the join query.
	reply, err := m.SendJoinRequest(&mysqlclient.DBInfo{
		Table: m.Table, Func: "Find", Model: m.Model,
		Alias: info.Alias, Field: info.Field, Cols: info.Cols,
		Start: start, Rows: info.Rows, Order: info.Order,
		Where: getWhere(info.Where), Join: buildBuildCondition(info.Join),
	})
	if err != nil || !reply.Has {
		return false, err
	}
	// Decode the returned list into the caller's message.
	err = proto.Unmarshal(reply.Model, info.ListResult)
	return reply.Has, err
}

// JoinFindAndCount fetches one page of joined records into info.ListResult and
// returns the total number of records matching info.Where.
//
// Pagination is 1-based. A Page <= 0 is treated as the first page. The total
// is fetched even when the requested page is empty, so a page past the end
// still reports the real total instead of 0. info.ListResult is only written
// when the list request returned rows.
//
// It returns 0 and an error when info.Join is empty, when either request
// fails, or when decoding the list fails.
func (m *MysqlHandle) JoinFindAndCount(info *FindInfo) (int64, error) {
	// A join query needs at least one join clause. len() also rejects a
	// non-nil empty slice, which would otherwise be sent as "no join".
	if len(info.Join) == 0 {
		return 0, errors.New("JoinCondition not allowed to be empty")
	}
	// Guard against Page <= 0, which would yield a negative LIMIT offset.
	start := (info.Page - 1) * info.Rows
	if start < 0 {
		start = 0
	}
	// Fetch the requested page of joined rows.
	replyList, err := m.SendJoinRequest(&mysqlclient.DBInfo{
		Table: m.Table, Func: "Find", Model: m.Model,
		Alias: info.Alias, Field: info.Field, Cols: info.Cols,
		Start: start, Rows: info.Rows, Order: info.Order,
		Where: getWhere(info.Where), Join: buildBuildCondition(info.Join),
	})
	if err != nil {
		return 0, err
	}
	// Decode list data only when the page actually contains rows.
	if replyList.Has {
		if err = proto.Unmarshal(replyList.Model, info.ListResult); err != nil {
			return 0, err
		}
	}
	// Fetch the total count.
	// NOTE(review): this count is sent via SendRequest without Join/Alias, so it
	// counts rows of m.Table alone. If info.Where references aliased or joined
	// columns, the count may fail or differ from the joined result set — confirm.
	replyCount, err := m.SendRequest(&mysqlclient.DBInfo{Table: m.Table, Func: "Count", Model: m.Model, Where: getWhere(info.Where)})
	if err != nil {
		return 0, err
	}
	return replyCount.Value.Int, nil
}

// buildBuildCondition converts this package's JoinInfo descriptions into the
// mysqlclient.JoinInfo values carried by a join request. It returns nil for a
// nil or empty list. Nil entries are skipped rather than dereferenced.
func buildBuildCondition(joinList []*JoinInfo) []*mysqlclient.JoinInfo {
	// len of a nil slice is 0, so this covers both nil and empty input.
	if len(joinList) == 0 {
		return nil
	}
	join := make([]*mysqlclient.JoinInfo, 0, len(joinList))
	for _, v := range joinList {
		if v == nil {
			continue
		}
		join = append(join, &mysqlclient.JoinInfo{
			Type:  v.Type,
			Table: v.Table,
			Alias: v.Alias,
			On:    v.On,
		})
	}
	return join
}

// SendJoinRequest copies the handle's RequestId, Region and SqlScenes onto
// m.Client and attaches info as the request payload. It then delegates to the
// client's JoinRequest and returns its response and error unchanged.
//
// NOTE(review): if m.Client is a pointer type, these assignments mutate the
// client shared by the handle, which would race under concurrent use — confirm.
func (m *MysqlHandle) SendJoinRequest(info *mysqlclient.DBInfo) (reply param.Response, err error) {
	c := m.Client
	c.RequestId, c.Region, c.SqlScenes, c.DBInfo = m.RequestId, m.Region, m.SqlScenes, info
	return c.JoinRequest()
}
